# https://github.com/jacobgil/pytorch-grad-cam
!pip install --quiet grad-cam
# output: WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from PIL import Image
import os
from tqdm import tqdm
from collections import defaultdict
from random import sample
%config InlineBackend.figure_format = 'jpg'
# Pick the compute device once: first CUDA GPU when available, else CPU.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if has_gpu else "cpu")
print(device)
# output: cuda:0
# Test-time preprocessing: PIL image -> float tensor scaled to [0, 1].
# No augmentation here — inference only.
transform = transforms.Compose([transforms.ToTensor()])
import torch.nn as nn
import torch.nn.functional as F
class LongcatNet(nn.Module):
def __init__(self):
super().__init__()
self.bn1 = nn.BatchNorm2d(3)
self.conv1 = nn.Conv2d(3, 9, 3)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2_bn = nn.BatchNorm2d(9)
self.conv2 = nn.Conv2d(9, 16, 3)
self.pool2 = nn.MaxPool2d(2, 2)
self.conv3_bn = nn.BatchNorm2d(16)
self.conv3 = nn.Conv2d(16, 25, 3)
self.pool3 = nn.MaxPool2d(2, 2)
self.conv4_bn = nn.BatchNorm2d(25)
self.conv4 = nn.Conv2d(25, 36, 3)
self.pool4 = nn.MaxPool2d(2, 2)
self.conv5_bn = nn.BatchNorm2d(36)
self.conv5 = nn.Conv2d(36, 36, 3)
self.conv6_bn = nn.BatchNorm2d(36)
self.conv6 = nn.Conv2d(36, 49, 3)
self.conv7_bn = nn.BatchNorm2d(49)
self.conv7 = nn.Conv2d(49, 49, 3)
self.conv8_bn = nn.BatchNorm2d(49)
self.conv8 = nn.Conv2d(49, 49, 3)
self.conv9_bn = nn.BatchNorm2d(49)
self.conv9 = nn.Conv2d(49, 49, 3)
self.pool9 = nn.MaxPool2d(2, 2)
self.conv10_bn = nn.BatchNorm2d(49)
self.conv10 = nn.Conv2d(49, 49, 3)
self.pool10 = nn.MaxPool2d(2, 2)
self.fc = nn.Linear(1764, 4)
def forward(self, x):
x = self.bn1(x)
x = self.conv2_bn(self.pool1(F.relu(self.conv1(x))))
x = self.conv3_bn(self.pool2(F.relu(self.conv2(x))))
x = self.conv4_bn(self.pool3(F.relu(self.conv3(x))))
x = self.conv5_bn(self.pool4(F.relu(self.conv4(x))))
x = self.conv6_bn(F.relu(self.conv5(x)))
x = self.conv7_bn(F.relu(self.conv6(x)))
x = self.conv8_bn(F.relu(self.conv7(x)))
x = self.conv9_bn(F.relu(self.conv8(x)))
x = self.conv10_bn(self.pool9(F.relu(self.conv9(x))))
x = self.pool10(F.relu(self.conv10(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = self.fc(x)
return x
# Build the network, restore the trained weights, and switch to inference
# mode.  map_location makes the load work on CPU-only machines too.
# (Original moved the model to `device` both before and after loading; one
# move after loading is sufficient.)
longcat = LongcatNet()
longcat.load_state_dict(
    torch.load('saved_models/longcat/epoch_7_batch_5000.pth', map_location=device)
)
longcat = longcat.to(device)
longcat.eval()
# output: LongcatNet( (bn1): BatchNorm2d(3, ...) (conv1): Conv2d(3, 9, kernel_size=(3, 3), stride=(1, 1)) ... (fc): Linear(in_features=1764, out_features=4, bias=True) )
# Integer label assigned to each of the four city classes (order matters:
# it matches the label encoding used when the model was trained).
class_codes = {
    city: code
    for code, city in enumerate(('Amsterdam', 'Firenca', 'LasVegas', 'NYC'))
}
def get_classifications(source):
    """Run the model over every test image in dataset/test/<source>.

    Returns a defaultdict mapping 'to_<city>' -> list of filenames the
    network classified as that city.

    (Original note, translated from Croatian: "A terrible and inefficient
    way to do this. But I'm writing this at 04:07, and sometimes a person
    compensates for mental energy with a surplus of compute.")
    """
    filenames = defaultdict(list)
    # Reverse lookup (class index -> city name), built once instead of
    # scanning all four target names for every image.
    code_to_city = {code: city for city, code in class_codes.items()}
    files = os.listdir(f'dataset/test/{source}')
    # Inference only — no_grad avoids building autograd graphs per image.
    with torch.no_grad():
        for fname in tqdm(files):
            if fname[-1] == 't':  # skip the .txt annotation files
                continue
            img = Image.open(f'dataset/test/{source}/' + fname)
            img = transform(img)
            img = torch.reshape(img, (1, 3, 640, 640))
            # Was img.cuda(): that crashes on CPU-only hosts; use the
            # module-level `device` instead.
            outputs = longcat(img.to(device))
            _, predicted = torch.max(outputs, 1)
            filenames[f'to_{code_to_city[int(predicted[0])]}'].append(fname)
    return filenames
# Classify the whole Las Vegas test split (~149k files; ~15 min on GPU).
lv_which_mapped_to = get_classifications('LasVegas')
# output: 100%|██████████| 148948/148948 [15:02<00:00, 164.99it/s]
# Notebook cell: how many Las Vegas images the net kept as Las Vegas.
len(lv_which_mapped_to['to_LasVegas'])
# output: 61393
# Notebook cell: misclassification counts (to Amsterdam, Firenca, NYC).
len(lv_which_mapped_to['to_Amsterdam']), len(lv_which_mapped_to['to_Firenca']), len(lv_which_mapped_to['to_NYC']),
# output: (3375, 2137, 7569)
# Shuffle each class's filenames and expose them as an iterator over
# batches of 4 (one batch per next() call, consumed by plot_grid below).
for key in lv_which_mapped_to:
    # BUG FIX: only `sample` is imported from random (L18), so the original
    # call to `shuffle(lst)` raised NameError.  sample(lst, len(lst))
    # yields a full random permutation with what is already in scope.
    shuffled = sample(lv_which_mapped_to[key], len(lv_which_mapped_to[key]))
    lv_which_mapped_to[key] = iter(
        [shuffled[i:i + 4] for i in range(0, len(shuffled), 4)]
    )
def plot_grid(images, target):
    """Plot up to 4 test images with their Grad-CAM overlays for `target`.

    Layout is a 4x2 grid: the originals land in axes 0, 1, 4, 5 and the
    corresponding class-activation maps two slots later (2, 3, 6, 7).
    Image files are read from dataset/test/LasVegas/ regardless of
    `target` — the batches come from the Las Vegas split, while `target`
    selects which class's activation is visualized.

    NOTE(review): relies on a module-level `target_layers` that is not
    defined in any visible cell — confirm it is set before calling.
    """
    n_row, n_col = 4, 2
    fig, axs = plt.subplots(n_row, n_col, figsize=(15, 30))
    axs = axs.flatten()

    # Build the CAM object once — the original re-created it for every
    # image inside the loop, re-hooking the model each time for no gain.
    cam = GradCAM(model=longcat, target_layers=target_layers, use_cuda=True)
    targets = [ClassifierOutputTarget(class_codes[target])]

    def _strip(ax):
        # Remove all ticks/labels so the grid shows only the images.
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])
        ax.margins(x=0, y=0, tight=True)

    for ix, fname in enumerate(images):
        pos = ix if ix <= 1 else ix + 2  # skip the CAM rows of the grid
        # Open each file once (the original opened every image twice).
        pil_img = Image.open('dataset/test/LasVegas/' + fname)
        axs[pos].imshow(pil_img)
        _strip(axs[pos])

        tensor = torch.reshape(transform(pil_img), (1, 3, 640, 640))
        grayscale_cam = cam(input_tensor=tensor, targets=targets)[0, :]
        rgb = torch.reshape(tensor.permute(0, 2, 3, 1), (640, 640, 3))
        visualization = show_cam_on_image(rgb.numpy(), grayscale_cam,
                                          use_rgb=True)
        axs[pos + 2].imshow(visualization)
        _strip(axs[pos + 2])

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.axis('off')
    plt.show()
# Visualize twelve batches of images the net labeled Las Vegas, then five
# batches it (mis)labeled as NYC — same call sequence as the original
# seventeen copy-pasted next()/plot_grid() pairs.
for _ in range(12):
    plot_grid(next(lv_which_mapped_to['to_LasVegas']), 'LasVegas')
for _ in range(5):
    plot_grid(next(lv_which_mapped_to['to_NYC']), 'NYC')